package ai.onnxruntime;

import ai.onnxruntime.OnnxValue;
import ai.onnxruntime.OrtUtil;
import ai.onnxruntime.TensorInfo;
import java.nio.Buffer;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.DoubleBuffer;
import java.nio.FloatBuffer;
import java.nio.IntBuffer;
import java.nio.LongBuffer;
import java.nio.ShortBuffer;
import java.util.Arrays;

/* loaded from: classes.dex */
public final class OnnxSparseTensor extends OnnxTensorLike {
    // Indices buffer supplied at construction. createSparseTensor passes the prepared buffer it
    // handed to native code; the int-sparsity-type constructor passes null.
    // NOTE(review): never read in this file — presumably held only to keep memory the native
    // tensor may reference strongly reachable; confirm.
    private final Buffer indices;
    // Inner indices buffer; non-null only when built from a CSRC tensor by createSparseTensor.
    private final LongBuffer innerIndices;
    // Sparsity format of this tensor; decides which accessor paths are valid.
    private final SparseTensorType sparseTensorType;
    // Values buffer supplied at construction (same retention note as indices above).
    private final Buffer values;

    /* renamed from: ai.onnxruntime.OnnxSparseTensor$1, reason: invalid class name */
    /* loaded from: classes.dex */
    /**
     * Compiler-synthesized switch tables (recovered by the decompiler). Each array maps an enum
     * constant's {@code ordinal()} to the dense case number used by the {@code switch} statements
     * in this file (e.g. 1 = COO, 2 = BLOCK_SPARSE, 3 = CSRC for the sparsity table).
     *
     * <p>Each store is wrapped in a {@code NoSuchFieldError} catch. If a constant is missing at
     * runtime, its slot keeps the default 0 and therefore falls through to a switch's
     * default/else branch. Do not renumber these entries: callers compare against the literal
     * case numbers.
     */
    public static /* synthetic */ class AnonymousClass1 {
        // Indexed by OnnxJavaType.ordinal(); consulted by getValuesBuffer().
        static final /* synthetic */ int[] $SwitchMap$ai$onnxruntime$OnnxJavaType;
        // Indexed by SparseTensorType.ordinal(); consulted by createSparseTensor, getIndicesBuffer and getValue.
        static final /* synthetic */ int[] $SwitchMap$ai$onnxruntime$OnnxSparseTensor$SparseTensorType;

        static {
            // OnnxJavaType -> case numbers 1..10.
            int[] iArr = new int[OnnxJavaType.values().length];
            $SwitchMap$ai$onnxruntime$OnnxJavaType = iArr;
            try {
                iArr[OnnxJavaType.FLOAT.ordinal()] = 1;
            } catch (NoSuchFieldError unused) {
            }
            try {
                $SwitchMap$ai$onnxruntime$OnnxJavaType[OnnxJavaType.DOUBLE.ordinal()] = 2;
            } catch (NoSuchFieldError unused2) {
            }
            try {
                $SwitchMap$ai$onnxruntime$OnnxJavaType[OnnxJavaType.INT16.ordinal()] = 3;
            } catch (NoSuchFieldError unused3) {
            }
            try {
                $SwitchMap$ai$onnxruntime$OnnxJavaType[OnnxJavaType.INT32.ordinal()] = 4;
            } catch (NoSuchFieldError unused4) {
            }
            try {
                $SwitchMap$ai$onnxruntime$OnnxJavaType[OnnxJavaType.INT64.ordinal()] = 5;
            } catch (NoSuchFieldError unused5) {
            }
            try {
                $SwitchMap$ai$onnxruntime$OnnxJavaType[OnnxJavaType.BOOL.ordinal()] = 6;
            } catch (NoSuchFieldError unused6) {
            }
            try {
                $SwitchMap$ai$onnxruntime$OnnxJavaType[OnnxJavaType.INT8.ordinal()] = 7;
            } catch (NoSuchFieldError unused7) {
            }
            try {
                $SwitchMap$ai$onnxruntime$OnnxJavaType[OnnxJavaType.UINT8.ordinal()] = 8;
            } catch (NoSuchFieldError unused8) {
            }
            try {
                $SwitchMap$ai$onnxruntime$OnnxJavaType[OnnxJavaType.STRING.ordinal()] = 9;
            } catch (NoSuchFieldError unused9) {
            }
            try {
                $SwitchMap$ai$onnxruntime$OnnxJavaType[OnnxJavaType.UNKNOWN.ordinal()] = 10;
            } catch (NoSuchFieldError unused10) {
            }
            // SparseTensorType -> case numbers 1..4.
            int[] iArr2 = new int[SparseTensorType.values().length];
            $SwitchMap$ai$onnxruntime$OnnxSparseTensor$SparseTensorType = iArr2;
            try {
                iArr2[SparseTensorType.COO.ordinal()] = 1;
            } catch (NoSuchFieldError unused11) {
            }
            try {
                $SwitchMap$ai$onnxruntime$OnnxSparseTensor$SparseTensorType[SparseTensorType.BLOCK_SPARSE.ordinal()] = 2;
            } catch (NoSuchFieldError unused12) {
            }
            try {
                $SwitchMap$ai$onnxruntime$OnnxSparseTensor$SparseTensorType[SparseTensorType.CSRC.ordinal()] = 3;
            } catch (NoSuchFieldError unused13) {
            }
            try {
                $SwitchMap$ai$onnxruntime$OnnxSparseTensor$SparseTensorType[SparseTensorType.UNDEFINED.ordinal()] = 4;
            } catch (NoSuchFieldError unused14) {
            }
        }
    }

    /* loaded from: classes.dex */
    /**
     * Block-sparse tensor description: int32 indices locate stored value blocks. Per the
     * validation messages, values are shaped {@code [numBlocks, blockSize, blockSize]} (or larger)
     * and indices {@code [numBlocks, co-ordinates]} (or larger).
     */
    public static final class BlockSparseTensor extends SparseTensor<IntBuffer> {
        /**
         * Validates and wraps block-sparse data.
         *
         * @param indices block index buffer; remaining() must equal the element count of indicesShape
         * @param indicesShape shape of the indices, rank 2 or more
         * @param values value buffer; remaining() must equal numNonZero
         * @param valuesShape shape of the values, rank 3 or more, whose element count equals numNonZero
         * @param denseShape shape of the equivalent dense tensor
         * @param type element type of the values
         * @param numNonZero number of stored value elements
         * @throws IllegalArgumentException if any size or rank check fails
         */
        public BlockSparseTensor(IntBuffer indices, long[] indicesShape, Buffer values, long[] valuesShape, long[] denseShape, OnnxJavaType type, long numNonZero) {
            super(indices, indicesShape, values, valuesShape, denseShape, type, numNonZero);
            long valueShapeCount = OrtUtil.elementCount(valuesShape);
            if (valueShapeCount != numNonZero) {
                throw new IllegalArgumentException("Expected " + numNonZero + " entries in the data shape, found " + Arrays.toString(valuesShape));
            }
            int storedValues = values.remaining();
            if (numNonZero != storedValues) {
                throw new IllegalArgumentException("Expected " + numNonZero + " elements in the data buffer, found " + storedValues);
            }
            long expectedIndices = OrtUtil.elementCount(indicesShape);
            int storedIndices = indices.remaining();
            if (expectedIndices != storedIndices) {
                throw new IllegalArgumentException("Expected " + expectedIndices + " elements in the indices buffer, found " + storedIndices);
            }
            if (valuesShape.length < 3) {
                throw new IllegalArgumentException("Expected [numBlocks, blockSize, blockSize] or larger, but data shape was " + Arrays.toString(valuesShape));
            }
            if (indicesShape.length < 2) {
                throw new IllegalArgumentException("Expected [numBlocks, co-ordinates] or larger, but indices shape was " + Arrays.toString(indicesShape));
            }
        }

        /** Block-sparse indices are always int32. */
        @Override // ai.onnxruntime.OnnxSparseTensor.SparseTensor
        public OnnxJavaType getIndicesType() {
            return OnnxJavaType.INT32;
        }

        /** Always {@link SparseTensorType#BLOCK_SPARSE}. */
        @Override // ai.onnxruntime.OnnxSparseTensor.SparseTensor
        public SparseTensorType getSparsityType() {
            return SparseTensorType.BLOCK_SPARSE;
        }
    }

    /* loaded from: classes.dex */
    /**
     * Coordinate-format (COO) sparse tensor description with int64 indices shaped either
     * {@code [numNonZero]} or {@code [numNonZero, dimension]}.
     */
    public static final class COOTensor extends SparseTensor<LongBuffer> {
        /**
         * Validates and wraps COO data. The values shape is always {@code [numNonZero]}.
         *
         * @param indices index buffer; remaining() must equal the element count of indicesShape
         * @param indicesShape rank-1 or rank-2 shape whose first dimension is numNonZero
         * @param values value buffer; remaining() must equal numNonZero
         * @param denseShape shape of the equivalent dense tensor
         * @param type element type of the values
         * @param numNonZero number of stored values
         * @throws IllegalArgumentException if the indices shape or buffer sizes are inconsistent
         */
        public COOTensor(LongBuffer indices, long[] indicesShape, Buffer values, long[] denseShape, OnnxJavaType type, long numNonZero) {
            super(indices, indicesShape, values, new long[]{numNonZero}, denseShape, type, numNonZero);
            boolean rankOk = indicesShape.length == 1 || indicesShape.length == 2;
            if (!rankOk || indicesShape[0] != numNonZero) {
                throw new IllegalArgumentException("Invalid indices shape, expected [numNonZero, dimension] or [numNonZero] found " + Arrays.toString(indicesShape));
            }
            long expectedIndices = OrtUtil.elementCount(indicesShape);
            if (indices.remaining() != expectedIndices) {
                throw new IllegalArgumentException("Unexpected number of indices found in buffer, expected " + expectedIndices + " found " + indices.remaining());
            }
            if (values.remaining() != numNonZero) {
                throw new IllegalArgumentException("Expected data.remaining() - " + values.remaining() + " to equal numNonZero - " + numNonZero);
            }
        }

        /** COO indices are always int64. */
        @Override // ai.onnxruntime.OnnxSparseTensor.SparseTensor
        public OnnxJavaType getIndicesType() {
            return OnnxJavaType.INT64;
        }

        /** Always {@link SparseTensorType#COO}. */
        @Override // ai.onnxruntime.OnnxSparseTensor.SparseTensor
        public SparseTensorType getSparsityType() {
            return SparseTensorType.COO;
        }
    }

    /* loaded from: classes.dex */
    /**
     * Compressed sparse row/column (CSRC) tensor description. The outer indices are stored as the
     * base-class indices and must number {@code denseShape[0] + 1}. The inner indices are kept
     * separately and must number {@code numNonZero}. Both are int64.
     */
    public static final class CSRCTensor extends SparseTensor<LongBuffer> {
        private final LongBuffer innerIndices;

        /**
         * Validates and wraps CSRC data. The indices shape is {@code [outerIndices.remaining()]}
         * and the values shape is {@code [numNonZero]}.
         *
         * @param outerIndices outer index buffer; remaining() must equal denseShape[0] + 1
         * @param innerIndices inner index buffer; remaining() must equal numNonZero
         * @param values value buffer; remaining() must equal numNonZero
         * @param denseShape shape of the equivalent dense tensor; must have at least one dimension
         * @param type element type of the values
         * @param numNonZero number of stored values
         * @throws IllegalArgumentException if the dense shape is empty or any size check fails
         */
        public CSRCTensor(LongBuffer outerIndices, LongBuffer innerIndices, Buffer values, long[] denseShape, OnnxJavaType type, long numNonZero) {
            super(outerIndices, new long[]{outerIndices.remaining()}, values, new long[]{numNonZero}, denseShape, type, numNonZero);
            this.innerIndices = innerIndices;
            // Guard before reading denseShape[0]. An empty shape previously surfaced as an
            // ArrayIndexOutOfBoundsException instead of a validation error.
            if (denseShape.length == 0) {
                throw new IllegalArgumentException("Dense shape must have at least one dimension for a CSRC tensor, found " + Arrays.toString(denseShape));
            }
            long expectedOuter = denseShape[0] + 1;
            if (outerIndices.remaining() != expectedOuter) {
                throw new IllegalArgumentException("Outer indices should be equal to the number of rows + 1 in the dense shape, found " + outerIndices.remaining() + ", expected " + expectedOuter);
            }
            if (innerIndices.remaining() != numNonZero) {
                throw new IllegalArgumentException("Inner indices should be equal to the number of non-zero elements, found " + innerIndices.remaining() + ", expected " + numNonZero);
            }
        }

        /** CSRC indices (outer and inner) are always int64. */
        @Override // ai.onnxruntime.OnnxSparseTensor.SparseTensor
        public OnnxJavaType getIndicesType() {
            return OnnxJavaType.INT64;
        }

        /** Returns the inner index buffer passed to the constructor (not a copy). */
        public LongBuffer getInnerIndices() {
            return this.innerIndices;
        }

        /** Returns {@code [innerIndices.remaining()]}, computed on each call. */
        public long[] getInnerIndicesShape() {
            return new long[]{this.innerIndices.remaining()};
        }

        /** Always {@link SparseTensorType#CSRC}. */
        @Override // ai.onnxruntime.OnnxSparseTensor.SparseTensor
        public SparseTensorType getSparsityType() {
            return SparseTensorType.CSRC;
        }
    }

    /* loaded from: classes.dex */
    /**
     * Java-side description of a sparse tensor: an indices buffer of type {@code T}, a values
     * buffer, and the shapes that go with them. Subclasses fix the sparsity format and the indices
     * element type.
     *
     * @param <T> concrete buffer type holding the indices
     */
    public static abstract class SparseTensor<T extends Buffer> {
        final T indices;
        final Buffer values;
        private final long[] indicesShape;
        private final long[] valuesShape;
        private final long[] denseShape;
        private final OnnxJavaType type;
        private final long numNonZero;

        /**
         * Stores the supplied buffers and shapes (arrays and buffers are kept by reference).
         *
         * @param indices indices buffer
         * @param indicesShape shape of the indices
         * @param values values buffer; remaining() must equal numNonZero
         * @param valuesShape shape of the values
         * @param denseShape shape of the equivalent dense tensor
         * @param type element type of the values; STRING is rejected
         * @param numNonZero number of stored values
         * @throws IllegalArgumentException if values.remaining() != numNonZero, or type is STRING
         */
        public SparseTensor(T indices, long[] indicesShape, Buffer values, long[] valuesShape, long[] denseShape, OnnxJavaType type, long numNonZero) {
            this.indices = indices;
            this.indicesShape = indicesShape;
            this.values = values;
            this.valuesShape = valuesShape;
            this.denseShape = denseShape;
            this.type = type;
            this.numNonZero = numNonZero;
            if (values.remaining() != numNonZero) {
                throw new IllegalArgumentException("Expected numNonZero and data.remaining to be equal, found " + numNonZero + " and " + values.remaining() + " respectively");
            }
            if (type == OnnxJavaType.STRING) {
                throw new IllegalArgumentException("String SparseTensors are not supported.");
            }
        }

        /** Returns the dense shape array passed to the constructor (not a copy). */
        public long[] getDenseShape() {
            return this.denseShape;
        }

        /** Returns the indices buffer passed to the constructor (not a copy). */
        public T getIndices() {
            return this.indices;
        }

        /** Returns the indices shape array passed to the constructor (not a copy). */
        public long[] getIndicesShape() {
            return this.indicesShape;
        }

        /** Element type of the indices buffer for this sparsity format. */
        public abstract OnnxJavaType getIndicesType();

        /** Number of stored (non-zero) values. */
        public long getNumNonZeroElements() {
            return this.numNonZero;
        }

        /** Sparsity format implemented by the subclass. */
        public abstract SparseTensorType getSparsityType();

        /** Element type of the values. */
        public OnnxJavaType getType() {
            return this.type;
        }

        /** Returns the values buffer passed to the constructor (not a copy). */
        public Buffer getValues() {
            return this.values;
        }

        /** Returns the values shape array passed to the constructor (not a copy). */
        public long[] getValuesShape() {
            return this.valuesShape;
        }
    }

    /* loaded from: classes.dex */
    /**
     * Sparse storage formats. Each constant carries the integer {@code value} exchanged with
     * native code: it is passed to createSparseTensorFromBuffer and decoded by
     * {@link #mapFromInt(int)}.
     */
    public enum SparseTensorType {
        /** Unknown or unset format. */
        UNDEFINED(0),
        /** Coordinate format. */
        COO(1),
        /** Compressed sparse row/column format (outer + inner indices). */
        CSRC(2),
        /** Block-sparse format. */
        BLOCK_SPARSE(4);

        /**
         * Lookup table indexed by {@link #value}. It is derived from the constants themselves so it
         * cannot fall out of sync with them. Unused codes (e.g. 3) map to UNDEFINED. This replaces
         * a decompiled initializer that referenced an undefined variable and did not compile.
         */
        private static final SparseTensorType[] BY_VALUE;

        /** Integer code for this format used across the JNI boundary. */
        public final int value;

        static {
            int maxValue = 0;
            for (SparseTensorType type : values()) {
                maxValue = Math.max(maxValue, type.value);
            }
            SparseTensorType[] table = new SparseTensorType[maxValue + 1];
            Arrays.fill(table, UNDEFINED);
            for (SparseTensorType type : values()) {
                table[type.value] = type;
            }
            BY_VALUE = table;
        }

        SparseTensorType(int value) {
            this.value = value;
        }

        /**
         * Maps a native integer code back to its enum constant.
         *
         * @param value the integer code
         * @return the matching constant, or UNDEFINED for unknown or out-of-range codes
         */
        public static SparseTensorType mapFromInt(int value) {
            if (value > 0 && value < BY_VALUE.length) {
                return BY_VALUE[value];
            }
            return UNDEFINED;
        }
    }

    /**
     * Wraps a native sparse tensor whose sparsity format arrives as an integer code, decoded via
     * {@link SparseTensorType#mapFromInt(int)}. No Java-side buffers are retained.
     */
    public OnnxSparseTensor(long nativeHandle, long allocatorHandle, int sparsityCode, TensorInfo info) {
        this(nativeHandle, allocatorHandle, SparseTensorType.mapFromInt(sparsityCode), info, null, null, null);
    }

    /**
     * Wraps a native non-CSRC sparse tensor, retaining the indices and values buffers that were
     * handed to native code. There are no inner indices.
     */
    public OnnxSparseTensor(long nativeHandle, long allocatorHandle, SparseTensorType sparsity, TensorInfo info, Buffer indicesBuffer, Buffer valuesBuffer) {
        this(nativeHandle, allocatorHandle, sparsity, info, indicesBuffer, null, valuesBuffer);
    }

    /**
     * Primary constructor: records the sparsity format and keeps references to the supplied
     * buffers (any of which may be null).
     */
    public OnnxSparseTensor(long nativeHandle, long allocatorHandle, SparseTensorType sparsity, TensorInfo info, Buffer indicesBuffer, LongBuffer innerIndicesBuffer, Buffer valuesBuffer) {
        super(nativeHandle, allocatorHandle, info);
        this.values = valuesBuffer;
        this.innerIndices = innerIndicesBuffer;
        this.indices = indicesBuffer;
        this.sparseTensorType = sparsity;
    }

    /** Native release of the tensor identified by {@code tensorHandle}; invoked by {@link #close()}. */
    private native void close(long apiHandle, long tensorHandle);

    /**
     * Native creation of a CSRC sparse tensor. createSparseTensor passes, in order: the API
     * handle, the allocator handle, the outer indices (buffer, position, size), the inner indices
     * (buffer, position, size), the values (buffer, position), the dense shape, the values shape,
     * and the ONNX element type code. The returned long becomes the new tensor's native handle.
     */
    private static native long createCSRCSparseTensorFromBuffer(long apiHandle, long allocatorHandle, Buffer outerIndices, int outerPos, long outerSize, Buffer innerIndices, int innerPos, long innerSize, Buffer values, int valuesPos, long[] denseShape, long[] valuesShape, int onnxTypeCode);

    /**
     * Creates a native sparse tensor from {@code sparseTensor} using the environment's default
     * allocator.
     *
     * @param ortEnvironment environment supplying the default allocator
     * @param sparseTensor Java-side sparse tensor description
     * @param <T> indices buffer type
     * @return the new OnnxSparseTensor
     */
    public static <T extends Buffer> OnnxSparseTensor createSparseTensor(OrtEnvironment ortEnvironment, SparseTensor<T> sparseTensor) {
        OrtAllocator allocator = ortEnvironment.defaultAllocator;
        return createSparseTensor(ortEnvironment, allocator, sparseTensor);
    }

    /**
     * Creates a native sparse tensor from {@code sparseTensor} using {@code ortAllocator}.
     * Indices and values are prepared with {@code OrtUtil.prepareBuffer}. The prepared buffers
     * are retained by the returned object.
     *
     * @param ortEnvironment environment (not otherwise used here)
     * @param ortAllocator allocator for the native tensor; must not be closed
     * @param sparseTensor Java-side sparse tensor description
     * @param <T> indices buffer type
     * @return the new OnnxSparseTensor
     * @throws IllegalStateException if the allocator is closed or the prepared indices buffer is
     *     neither an IntBuffer nor a LongBuffer
     * @throws IllegalArgumentException if the sparsity type is UNDEFINED
     */
    public static <T extends Buffer> OnnxSparseTensor createSparseTensor(OrtEnvironment ortEnvironment, OrtAllocator ortAllocator, SparseTensor<T> sparseTensor) {
        if (ortAllocator.isClosed()) {
            throw new IllegalStateException("Trying to create an OnnxSparseTensor on a closed OrtAllocator.");
        }
        TensorInfo info = TensorInfo.constructFromSparseTensor(sparseTensor);
        OnnxJavaType indicesType = sparseTensor.getIndicesType();
        OrtUtil.BufferTuple indicesTuple = OrtUtil.prepareBuffer(sparseTensor.getIndices(), indicesType);
        OrtUtil.BufferTuple valuesTuple = OrtUtil.prepareBuffer(sparseTensor.getValues(), info.type);
        Buffer preparedIndices = indicesTuple.data;
        boolean indicesSupported = preparedIndices instanceof LongBuffer || preparedIndices instanceof IntBuffer;
        if (!indicesSupported) {
            throw new IllegalStateException("Unexpected type of indices buffer, found " + indicesTuple.data.getClass() + ", expected IntBuffer or LongBuffer");
        }
        SparseTensorType sparsity = sparseTensor.getSparsityType();
        switch (sparsity) {
            case COO:
            case BLOCK_SPARSE: {
                long handle = createSparseTensorFromBuffer(
                        OnnxRuntime.ortApiHandle, ortAllocator.handle,
                        indicesTuple.data, indicesTuple.pos, indicesTuple.size,
                        valuesTuple.data, valuesTuple.pos,
                        info.shape, sparseTensor.getIndicesShape(), sparseTensor.getValuesShape(),
                        info.onnxType.value, sparsity.value);
                return new OnnxSparseTensor(handle, ortAllocator.handle, sparsity, info, indicesTuple.data, valuesTuple.data);
            }
            case CSRC: {
                // CSRC additionally carries inner indices, prepared with the same element type.
                OrtUtil.BufferTuple innerTuple = OrtUtil.prepareBuffer(((CSRCTensor) sparseTensor).getInnerIndices(), indicesType);
                long handle = createCSRCSparseTensorFromBuffer(
                        OnnxRuntime.ortApiHandle, ortAllocator.handle,
                        indicesTuple.data, indicesTuple.pos, indicesTuple.size,
                        innerTuple.data, innerTuple.pos, innerTuple.size,
                        valuesTuple.data, valuesTuple.pos,
                        info.shape, sparseTensor.getValuesShape(), info.onnxType.value);
                return new OnnxSparseTensor(handle, ortAllocator.handle, sparsity, info, indicesTuple.data, (LongBuffer) innerTuple.data, valuesTuple.data);
            }
            default:
                throw new IllegalArgumentException("Cannot create an UNDEFINED sparse tensor.");
        }
    }

    /**
     * Native creation of a COO or block-sparse tensor. createSparseTensor passes, in order: the
     * API handle, the allocator handle, the indices (buffer, position, size), the values (buffer,
     * position), the dense shape, the indices shape, the values shape, the ONNX element type code,
     * and the {@link SparseTensorType#value} code. The returned long becomes the new tensor's
     * native handle.
     */
    private static native long createSparseTensorFromBuffer(long apiHandle, long allocatorHandle, Buffer indices, int indicesPos, long indicesSize, Buffer values, int valuesPos, long[] denseShape, long[] indicesShape, long[] valuesShape, int onnxTypeCode, int sparsityCode);

    /** Native view of the indices bytes; the public wrapper applies native byte order and copies. */
    private native ByteBuffer getIndicesBuffer(long apiHandle, long tensorHandle);

    /** Native query for the indices shape. */
    private native long[] getIndicesShape(long apiHandle, long tensorHandle);

    /** Native view of the CSRC inner-indices bytes; the public wrapper applies native byte order and copies. */
    private native ByteBuffer getInnerIndicesBuffer(long apiHandle, long tensorHandle);

    /** Native query for the CSRC inner-indices shape. */
    private native long[] getInnerIndicesShape(long apiHandle, long tensorHandle);

    /** Native view of the values bytes; the public wrapper applies native byte order and converts. */
    private native ByteBuffer getValuesBuffer(long apiHandle, long tensorHandle);

    /** Native query for the values shape. */
    private native long[] getValuesShape(long apiHandle, long tensorHandle);

    /**
     * Releases the native sparse tensor.
     * NOTE(review): there is no guard against a second call. Confirm whether the native release
     * tolerates double-close.
     */
    @Override // ai.onnxruntime.OnnxValue, java.lang.AutoCloseable
    public void close() {
        final long apiHandle = OnnxRuntime.ortApiHandle;
        close(apiHandle, this.nativeHandle);
    }

    /**
     * Copies the indices into a heap buffer in native byte order, rewound to position 0.
     *
     * @return an IntBuffer for BLOCK_SPARSE, or a LongBuffer for COO and CSRC (outer indices)
     * @throws IllegalStateException if the sparsity type is UNDEFINED
     */
    public Buffer getIndicesBuffer() {
        switch (this.sparseTensorType) {
            case BLOCK_SPARSE: {
                IntBuffer intView = getIndicesBuffer(OnnxRuntime.ortApiHandle, this.nativeHandle).order(ByteOrder.nativeOrder()).asIntBuffer();
                IntBuffer intCopy = IntBuffer.allocate(intView.capacity()).put(intView);
                intCopy.rewind();
                return intCopy;
            }
            case COO:
            case CSRC: {
                LongBuffer longView = getIndicesBuffer(OnnxRuntime.ortApiHandle, this.nativeHandle).order(ByteOrder.nativeOrder()).asLongBuffer();
                LongBuffer longCopy = LongBuffer.allocate(longView.capacity()).put(longView);
                longCopy.rewind();
                return longCopy;
            }
            default:
                throw new IllegalStateException("UNDEFINED sparse tensor type.");
        }
    }

    /** Returns the shape of the indices as reported by the native tensor. */
    public long[] getIndicesShape() {
        final long apiHandle = OnnxRuntime.ortApiHandle;
        return getIndicesShape(apiHandle, this.nativeHandle);
    }

    /**
     * Copies the CSRC inner indices into a heap LongBuffer in native byte order, rewound to
     * position 0.
     *
     * @throws IllegalStateException if this tensor is not CSRC
     */
    public LongBuffer getInnerIndicesBuffer() {
        if (this.sparseTensorType == SparseTensorType.CSRC) {
            ByteBuffer rawBytes = getInnerIndicesBuffer(OnnxRuntime.ortApiHandle, this.nativeHandle);
            LongBuffer longView = rawBytes.order(ByteOrder.nativeOrder()).asLongBuffer();
            LongBuffer longCopy = LongBuffer.allocate(longView.capacity()).put(longView);
            longCopy.rewind();
            return longCopy;
        }
        throw new IllegalStateException("Inner indices are only available for CSRC sparse tensors, this sparse tensor is " + this.sparseTensorType);
    }

    /**
     * Returns the CSRC inner-indices shape as reported by the native tensor.
     *
     * @throws IllegalStateException if this tensor is not CSRC
     */
    public long[] getInnerIndicesShape() {
        if (this.sparseTensorType != SparseTensorType.CSRC) {
            throw new IllegalStateException("Inner indices are only available for CSRC sparse tensors, this sparse tensor is " + this.sparseTensorType);
        }
        return getInnerIndicesShape(OnnxRuntime.ortApiHandle, this.nativeHandle);
    }

    /** Returns the sparsity format recorded at construction. */
    public SparseTensorType getSparseTensorType() {
        final SparseTensorType sparsity = this.sparseTensorType;
        return sparsity;
    }

    /** Always {@code ONNX_TYPE_SPARSETENSOR}. */
    @Override // ai.onnxruntime.OnnxValue
    public OnnxValue.OnnxValueType getType() {
        final OnnxValue.OnnxValueType kind = OnnxValue.OnnxValueType.ONNX_TYPE_SPARSETENSOR;
        return kind;
    }

    /**
     * Copies this native sparse tensor back into a Java-side {@link SparseTensor} of the matching
     * subclass. The number of non-zero elements is taken from the copied values buffer.
     *
     * @throws IllegalStateException if the sparsity type is UNDEFINED
     */
    @Override // ai.onnxruntime.OnnxValue
    public SparseTensor<? extends Buffer> getValue() {
        Buffer copiedValues = getValuesBuffer();
        long[] idxShape = getIndicesShape(OnnxRuntime.ortApiHandle, this.nativeHandle);
        switch (this.sparseTensorType) {
            case COO: {
                LongBuffer coordinates = (LongBuffer) getIndicesBuffer();
                return new COOTensor(coordinates, idxShape, copiedValues, this.info.shape, this.info.type, copiedValues.remaining());
            }
            case BLOCK_SPARSE: {
                long[] blockValuesShape = getValuesShape(OnnxRuntime.ortApiHandle, this.nativeHandle);
                IntBuffer blockIndices = (IntBuffer) getIndicesBuffer();
                return new BlockSparseTensor(blockIndices, idxShape, copiedValues, blockValuesShape, this.info.shape, this.info.type, copiedValues.remaining());
            }
            case CSRC: {
                LongBuffer outer = (LongBuffer) getIndicesBuffer();
                LongBuffer inner = getInnerIndicesBuffer();
                return new CSRCTensor(outer, inner, copiedValues, this.info.shape, this.info.type, copiedValues.remaining());
            }
            default:
                throw new IllegalStateException("Undefined sparsity type in this sparse tensor.");
        }
    }

    /**
     * Copies the values into a heap buffer in native byte order, rewound to position 0. The
     * buffer type follows {@code info.type}. FLOAT16 values are widened to a FloatBuffer through
     * {@code OnnxTensor.fp16ToFloat}.
     *
     * @throws IllegalArgumentException for BFLOAT16 values
     * @throws IllegalStateException for STRING or any other unsupported element type
     */
    public Buffer getValuesBuffer() {
        ByteBuffer raw = getValuesBuffer(OnnxRuntime.ortApiHandle, this.nativeHandle).order(ByteOrder.nativeOrder());
        switch (this.info.type) {
            case FLOAT: {
                TensorInfo.OnnxTensorType elementType = this.info.onnxType;
                if (elementType == TensorInfo.OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16) {
                    ShortBuffer halves = raw.asShortBuffer();
                    int count = halves.capacity();
                    FloatBuffer widened = FloatBuffer.allocate(count);
                    for (int idx = 0; idx < count; idx++) {
                        widened.put(OnnxTensor.fp16ToFloat(halves.get(idx)));
                    }
                    widened.rewind();
                    return widened;
                }
                if (elementType == TensorInfo.OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16) {
                    throw new IllegalArgumentException("BFloat16 is not supported.");
                }
                FloatBuffer floatView = raw.asFloatBuffer();
                FloatBuffer floatCopy = FloatBuffer.allocate(floatView.capacity()).put(floatView);
                floatCopy.rewind();
                return floatCopy;
            }
            case DOUBLE: {
                DoubleBuffer doubleView = raw.asDoubleBuffer();
                DoubleBuffer doubleCopy = DoubleBuffer.allocate(doubleView.capacity()).put(doubleView);
                doubleCopy.rewind();
                return doubleCopy;
            }
            case INT16: {
                ShortBuffer shortView = raw.asShortBuffer();
                ShortBuffer shortCopy = ShortBuffer.allocate(shortView.capacity()).put(shortView);
                shortCopy.rewind();
                return shortCopy;
            }
            case INT32: {
                IntBuffer intView = raw.asIntBuffer();
                IntBuffer intCopy = IntBuffer.allocate(intView.capacity()).put(intView);
                intCopy.rewind();
                return intCopy;
            }
            case INT64: {
                LongBuffer longView = raw.asLongBuffer();
                LongBuffer longCopy = LongBuffer.allocate(longView.capacity()).put(longView);
                longCopy.rewind();
                return longCopy;
            }
            case BOOL:
            case INT8:
            case UINT8: {
                ByteBuffer byteCopy = ByteBuffer.allocate(raw.capacity()).put(raw);
                byteCopy.rewind();
                return byteCopy;
            }
            case STRING:
                throw new IllegalStateException("Unsupported data type String");
            default:
                throw new IllegalStateException("Unsupported data type");
        }
    }

    /** Returns the shape of the values as reported by the native tensor. */
    public long[] getValuesShape() {
        final long apiHandle = OnnxRuntime.ortApiHandle;
        return getValuesShape(apiHandle, this.nativeHandle);
    }
}
